{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "import os\n",
    "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\";\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\" "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "using Keras version: 2.2.4\n"
     ]
    }
   ],
   "source": [
    "import ktrain\n",
    "from ktrain import graph as gr"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Node Classification in Graphs\n",
    "\n",
    "Consider a social network (e.g., Facebook, LinkedIn, Twitter) where each node is a person and links represent friendships. Each node (or person) in the graph can be described by various attributes such as their location, alma mater, organizational memberships, gender, relationship status, children, etc.  Suppose we knew the U.S. political affiliation (e.g., Democrat, Republican, Libertarian, Green Party) of only a small subset of nodes, with the affiliation of the remaining nodes being unknown.  Here, node classification involves predicting the political affiliation of the unknown nodes based only on the small subset of nodes for which we know the political affiliation.\n",
    "\n",
    "Whereas traditional tabular models (e.g., logistic regression, SVM) use only a node's attributes to predict its label, graph neural networks use both the node's attributes and the graph's structure. For instance, to predict the political affiliation of a person, it helps to look not only at the person's attributes but also at the attributes of other people in the vicinity of this person in the social network: birds of a feather typically flock together. By exploiting graph structure, graph neural networks can often get by with much less labeled ground truth than non-graph approaches.  For instance, in the example below, we use the labels of only a very small fraction of all nodes to build our model.\n",
    "\n",
    "## Hateful Twitter Users\n",
    "In this notebook, we will use *ktrain* to perform node classification on a Twitter graph to predict hateful users. Each Twitter user is described by various attributes related to both their profile and their tweeting behavior. Examples include the number of tweets and retweets, status length, etc.\n",
    "\n",
    "The dataset can be downloaded from [here](https://www.kaggle.com/manoelribeiro/hateful-users-on-twitter).\n",
    "\n",
    "For node classification, *ktrain* requires two files formatted in a specific way:\n",
    "- A CSV or tab-delimited file containing the links (or edges) in the graph.  Each row contains two node IDs representing an edge.\n",
    "- A CSV or tab-delimited file describing the attributes and label associated with each node.  The first column is the node ID and the last column should be the label or target (as string labels such as \"hateful\" or \"normal\").  All other columns should contain numerical features and are assumed to be standardized or transformed as necessary.  If the last column representing the target has missing values, these are treated as a holdout set for which predictions can be made after training the model. The numeric feature columns should not have any missing values.\n"
   ]
  },
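  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the two formats concrete, here is a minimal sketch that writes a toy edge list and a toy node file with pandas. The node IDs, the feature names `f1` and `f2`, the label values, and the file names are made up for illustration; only the layout follows the description above: node ID first, numeric features in the middle, string label last, and a missing label for a node that should be held out. The sketch writes comma-separated files; a tab-delimited version would differ only in the `sep` argument.\n",
    "\n",
    "```python\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "def write_toy_graph_files(out_dir):\n",
    "    # Hypothetical toy graph: 4 nodes and 3 edges, one edge (two node IDs) per row.\n",
    "    edges = pd.DataFrame({'source': ['0', '0', '1'], 'target': ['1', '2', '3']})\n",
    "    # Node table: node ID as the index, numeric features next, string label last.\n",
    "    # Node '3' has a missing label (NaN), so it would end up in the holdout set.\n",
    "    nodes = pd.DataFrame(\n",
    "        {'f1': [0.1, -1.2, 0.7, 0.0],\n",
    "         'f2': [1.5, 0.3, -0.4, 0.9],\n",
    "         'label': ['hateful', 'normal', 'normal', np.nan]},\n",
    "        index=['0', '1', '2', '3'])\n",
    "    edge_path = os.path.join(out_dir, 'toy-edges.csv')\n",
    "    node_path = os.path.join(out_dir, 'toy-nodes.csv')\n",
    "    # No header rows: every line is data.\n",
    "    edges.to_csv(edge_path, header=False, index=False)\n",
    "    nodes.to_csv(node_path, header=False)  # the index (node ID) becomes the first column\n",
    "    return edge_path, node_path\n",
    "\n",
    "edge_path, node_path = write_toy_graph_files(tempfile.mkdtemp())\n",
    "with open(node_path) as f:\n",
    "    print(f.read())\n",
    "```\n",
    "\n",
    "Reading the node file back with `pd.read_csv(node_path, header=None)` should yield four rows whose last column has exactly one missing value, i.e., the single node that would be treated as unlabeled."
   ]
  },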
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Clean and Prepare Data\n",
    "We must first transform the raw dataset into the file formats described above. We consider two files: `users.edges` which describes the graph structure and `users_neighborhood_anon.csv` which contains each node's label and attributes.   The file `users.edges` is the edge list and is already in the format expected by *ktrain* for the most part.  We must clean and prepare `users_neighborhood_anon.csv` into the format expected by *ktrain*. We will drop unused columns, normalize numeric attributes, re-order/transform the target column `hate` into an interpretable string label, and save the data as a tab-delimited file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>statuses_count</th>\n",
       "      <th>followers_count</th>\n",
       "      <th>followees_count</th>\n",
       "      <th>favorites_count</th>\n",
       "      <th>listed_count</th>\n",
       "      <th>negotiate_empath</th>\n",
       "      <th>vehicle_empath</th>\n",
       "      <th>science_empath</th>\n",
       "      <th>timidity_empath</th>\n",
       "      <th>gain_empath</th>\n",
       "      <th>...</th>\n",
       "      <th>tweet number</th>\n",
       "      <th>retweet number</th>\n",
       "      <th>quote number</th>\n",
       "      <th>status length</th>\n",
       "      <th>number urls</th>\n",
       "      <th>baddies</th>\n",
       "      <th>mentions</th>\n",
       "      <th>time_diff</th>\n",
       "      <th>time_diff_median</th>\n",
       "      <th>hate</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1.541150</td>\n",
       "      <td>0.046773</td>\n",
       "      <td>1.104767</td>\n",
       "      <td>1.869391</td>\n",
       "      <td>0.017835</td>\n",
       "      <td>-1.752256</td>\n",
       "      <td>0.164900</td>\n",
       "      <td>0.181173</td>\n",
       "      <td>0.875069</td>\n",
       "      <td>1.130523</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.049013</td>\n",
       "      <td>0.321929</td>\n",
       "      <td>-0.369992</td>\n",
       "      <td>-1.036127</td>\n",
       "      <td>-0.796091</td>\n",
       "      <td>0.047430</td>\n",
       "      <td>0.356495</td>\n",
       "      <td>-1.888186</td>\n",
       "      <td>-1.299249</td>\n",
       "      <td>normal</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>-0.700240</td>\n",
       "      <td>0.772450</td>\n",
       "      <td>-0.526061</td>\n",
       "      <td>-1.434183</td>\n",
       "      <td>0.613187</td>\n",
       "      <td>-0.735320</td>\n",
       "      <td>-0.864337</td>\n",
       "      <td>0.599279</td>\n",
       "      <td>1.610977</td>\n",
       "      <td>-1.203049</td>\n",
       "      <td>...</td>\n",
       "      <td>1.479066</td>\n",
       "      <td>-1.999580</td>\n",
       "      <td>-1.545285</td>\n",
       "      <td>-0.188945</td>\n",
       "      <td>-1.875745</td>\n",
       "      <td>-0.626192</td>\n",
       "      <td>-1.972207</td>\n",
       "      <td>0.160925</td>\n",
       "      <td>-1.512603</td>\n",
       "      <td>unknown</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>-1.077284</td>\n",
       "      <td>-0.127775</td>\n",
       "      <td>0.767345</td>\n",
       "      <td>-0.669050</td>\n",
       "      <td>-0.523882</td>\n",
       "      <td>-0.118440</td>\n",
       "      <td>-1.573040</td>\n",
       "      <td>1.211083</td>\n",
       "      <td>-0.154213</td>\n",
       "      <td>0.932754</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.201320</td>\n",
       "      <td>0.452537</td>\n",
       "      <td>-1.545285</td>\n",
       "      <td>0.637869</td>\n",
       "      <td>0.884530</td>\n",
       "      <td>-0.096918</td>\n",
       "      <td>0.348954</td>\n",
       "      <td>0.698841</td>\n",
       "      <td>0.122176</td>\n",
       "      <td>unknown</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1.908494</td>\n",
       "      <td>-0.021575</td>\n",
       "      <td>-0.548705</td>\n",
       "      <td>0.078540</td>\n",
       "      <td>0.017835</td>\n",
       "      <td>-0.472125</td>\n",
       "      <td>1.281633</td>\n",
       "      <td>-0.544862</td>\n",
       "      <td>1.259492</td>\n",
       "      <td>-0.456470</td>\n",
       "      <td>...</td>\n",
       "      <td>-1.018822</td>\n",
       "      <td>1.085858</td>\n",
       "      <td>-0.662393</td>\n",
       "      <td>-0.701835</td>\n",
       "      <td>0.088472</td>\n",
       "      <td>-0.626192</td>\n",
       "      <td>-1.254997</td>\n",
       "      <td>-1.576801</td>\n",
       "      <td>-1.311031</td>\n",
       "      <td>unknown</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>-0.778589</td>\n",
       "      <td>0.729918</td>\n",
       "      <td>2.296049</td>\n",
       "      <td>-0.725089</td>\n",
       "      <td>0.700128</td>\n",
       "      <td>-1.488804</td>\n",
       "      <td>-1.573040</td>\n",
       "      <td>-0.969812</td>\n",
       "      <td>0.199834</td>\n",
       "      <td>-1.203049</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.427866</td>\n",
       "      <td>0.638106</td>\n",
       "      <td>-1.545285</td>\n",
       "      <td>1.370832</td>\n",
       "      <td>0.655433</td>\n",
       "      <td>0.955922</td>\n",
       "      <td>-1.914894</td>\n",
       "      <td>0.803553</td>\n",
       "      <td>1.472247</td>\n",
       "      <td>unknown</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 205 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   statuses_count  followers_count  followees_count  favorites_count  \\\n",
       "0        1.541150         0.046773         1.104767         1.869391   \n",
       "1       -0.700240         0.772450        -0.526061        -1.434183   \n",
       "2       -1.077284        -0.127775         0.767345        -0.669050   \n",
       "3        1.908494        -0.021575        -0.548705         0.078540   \n",
       "4       -0.778589         0.729918         2.296049        -0.725089   \n",
       "\n",
       "   listed_count  negotiate_empath  vehicle_empath  science_empath  \\\n",
       "0      0.017835         -1.752256        0.164900        0.181173   \n",
       "1      0.613187         -0.735320       -0.864337        0.599279   \n",
       "2     -0.523882         -0.118440       -1.573040        1.211083   \n",
       "3      0.017835         -0.472125        1.281633       -0.544862   \n",
       "4      0.700128         -1.488804       -1.573040       -0.969812   \n",
       "\n",
       "   timidity_empath  gain_empath  ...  tweet number  retweet number  \\\n",
       "0         0.875069     1.130523  ...     -0.049013        0.321929   \n",
       "1         1.610977    -1.203049  ...      1.479066       -1.999580   \n",
       "2        -0.154213     0.932754  ...     -0.201320        0.452537   \n",
       "3         1.259492    -0.456470  ...     -1.018822        1.085858   \n",
       "4         0.199834    -1.203049  ...     -0.427866        0.638106   \n",
       "\n",
       "   quote number  status length  number urls   baddies  mentions  time_diff  \\\n",
       "0     -0.369992      -1.036127    -0.796091  0.047430  0.356495  -1.888186   \n",
       "1     -1.545285      -0.188945    -1.875745 -0.626192 -1.972207   0.160925   \n",
       "2     -1.545285       0.637869     0.884530 -0.096918  0.348954   0.698841   \n",
       "3     -0.662393      -0.701835     0.088472 -0.626192 -1.254997  -1.576801   \n",
       "4     -1.545285       1.370832     0.655433  0.955922 -1.914894   0.803553   \n",
       "\n",
       "   time_diff_median     hate  \n",
       "0         -1.299249   normal  \n",
       "1         -1.512603  unknown  \n",
       "2          0.122176  unknown  \n",
       "3         -1.311031  unknown  \n",
       "4          1.472247  unknown  \n",
       "\n",
       "[5 rows x 205 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# useful imports\n",
    "import sklearn.preprocessing  # import the submodule explicitly; it provides PowerTransformer\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# read in data\n",
    "data_dir = 'data/hateful-twitter-users/'\n",
    "users_feat = pd.read_csv(os.path.join(data_dir, 'users_neighborhood_anon.csv'))\n",
    "\n",
    "# clean the data and drop unused columns\n",
    "def data_cleaning(feat):\n",
    "    feat = feat.drop(columns=[\"hate_neigh\", \"normal_neigh\"])\n",
    "    # Convert target values in hate column from strings to integers (hateful=1, normal=0, anything else=2)\n",
    "    feat['hate'] = np.where(feat['hate']=='hateful', 1, np.where(feat['hate']=='normal', 0, 2))\n",
    "    # Replace NA with 0\n",
    "    feat.fillna(0, inplace=True)\n",
    "    # drop info about suspension and deletion, as it should not be used in the predictive model\n",
    "    feat.drop(feat.columns[feat.columns.str.contains(\"is_\")], axis=1, inplace=True)\n",
    "    # drop glove features\n",
    "    feat.drop(feat.columns[feat.columns.str.contains(\"_glove\")], axis=1, inplace=True)\n",
    "    # drop c_ features\n",
    "    feat.drop(feat.columns[feat.columns.str.contains(\"c_\")], axis=1, inplace=True)\n",
    "    # drop sentiment features for now\n",
    "    feat.drop(feat.columns[feat.columns.str.contains(\"sentiment\")], axis=1, inplace=True)\n",
    "    # drop hashtag feature\n",
    "    feat.drop(['hashtags'], axis=1, inplace=True)\n",
    "    # Drop centrality based measures\n",
    "    feat.drop(columns=['betweenness', 'eigenvector', 'in_degree', 'out_degree'], inplace=True)\n",
    "    # drop account creation date\n",
    "    feat.drop(columns=['created_at'], inplace=True)\n",
    "    return feat\n",
    "\n",
    "node_data = data_cleaning(users_feat)\n",
    "\n",
    "# recode the target column into human-readable string labels\n",
    "node_data = node_data.replace({'hate': {0:'normal', 1:'hateful', 2:'unknown'}})\n",
    "\n",
    "# normalize the numeric columns (skip the first two columns: user_id and the label column hate)\n",
    "df_values = node_data.iloc[:, 2:].values\n",
    "pt = sklearn.preprocessing.PowerTransformer(method='yeo-johnson', standardize=True)\n",
    "df_values_log = pt.fit_transform(df_values)\n",
    "node_data.iloc[:, 2:] = df_values_log\n",
    "\n",
    "# drop user_id and use the equivalent index as node ID\n",
    "node_data.index = node_data.index.map(str)\n",
    "node_data.drop(columns=['user_id'], inplace=True)\n",
    "\n",
    "# move target column to last position\n",
    "cols = list(node_data)\n",
    "cols.remove('hate')\n",
    "cols.append('hate')\n",
    "node_data = node_data.reindex(columns=cols)\n",
    "node_data.head()"
   ]
  },
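  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cleaning cell above normalizes the features with scikit-learn's `PowerTransformer(method='yeo-johnson', standardize=True)`. As a small illustration on synthetic data (the log-normal column below is made up, loosely mimicking skewed count features such as follower counts), the transformer first applies a fitted Yeo-Johnson power transform to make the distribution more Gaussian-like and then standardizes it, so the output has zero mean and unit variance.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from sklearn.preprocessing import PowerTransformer\n",
    "\n",
    "# Synthetic, heavily right-skewed values (made-up data for illustration).\n",
    "rng = np.random.RandomState(0)\n",
    "counts = rng.lognormal(mean=3.0, sigma=1.5, size=(1000, 1))\n",
    "\n",
    "# Fit one Yeo-Johnson lambda per column, then standardize the result.\n",
    "pt = PowerTransformer(method='yeo-johnson', standardize=True)\n",
    "transformed = pt.fit_transform(counts)\n",
    "\n",
    "print('fitted lambda:', pt.lambdas_[0])\n",
    "print('mean: %.3f, std: %.3f' % (transformed.mean(), transformed.std()))\n",
    "```\n",
    "\n",
    "Unlike Box-Cox, Yeo-Johnson also accepts zero and negative inputs, which matters for count-like features that can be 0."
   ]
  },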
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# save both data files as tab-delimited to be consistent\n",
    "node_data.to_csv('/tmp/twitter-nodes.tab', sep='\\t', header=False)\n",
    "edge_data = pd.read_csv(data_dir+'/users.edges', header=None, names=['Source', 'Destination'])\n",
    "edge_data.to_csv('/tmp/twitter-edges.tab', sep='\\t', header=False, index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# check to make sure there are no missing values in the non-target columns\n",
    "node_data[node_data.drop('hate', axis=1).isnull().any(axis=1)].shape[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### STEP 1: Load and Preprocess Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here, we will load the preprocessed dataset.  Of the nodes that are annotated with labels (i.e., hateful vs. normal), we will use 15% as the training set and 85% for validation.  For the nodes with no labels (i.e., tagged as 'unknown' in this dataset), we will create `df_holdout` and `G_complete`.  The dataframe `df_holdout` contains the features of the unlabeled nodes, and `G_complete` is the entire graph, including the nodes in `df_holdout` that were held out.\n",
    "\n",
    "If `holdout_for_inductive=True`, then features of holdout nodes are **not** visible during training, and the training graph is a subgraph of `G_complete`.  Otherwise, the features (but not the labels) of the held-out nodes can be exploited during training.  Setting `holdout_for_inductive=True` is useful for assessing how well your model can make predictions for new nodes that are added to the graph later, using `G_complete` (*inductive inference*). Here, we set `holdout_for_inductive=False`, as we would like to use the features of unlabeled nodes to help learn to make accurate predictions.  `G_complete`, then, is identical to the training graph and is not used, as we are only doing *transductive inference*.  See this example notebook to better see the difference between *transductive* and *inductive* inference in graphs."
   ]
  },
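  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check on the split sizes reported by `graph_nodes_from_csv` below: the labeled nodes number 4,427 normal + 544 hateful = 4,971, and 15% of 4,971 is 745.65, so rounding down gives the 745 training nodes and leaves 4,226 for validation; adding the 95,415 unlabeled nodes gives the 100,386 nodes of the training graph. The sketch below reproduces these counts with a hypothetical helper, `split_labeled_nodes`. It is not *ktrain*'s code; it uses a plain random split and assumes the training size is rounded down, which matches the reported counts.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def split_labeled_nodes(labels, train_pct=0.15, missing_label_value='unknown', seed=42):\n",
    "    # Hypothetical helper illustrating the idea; NOT ktrain's internal implementation.\n",
    "    labels = np.asarray(labels)\n",
    "    labeled_idx = np.where(labels != missing_label_value)[0]\n",
    "    holdout_idx = np.where(labels == missing_label_value)[0]\n",
    "    shuffled = np.random.RandomState(seed).permutation(labeled_idx)\n",
    "    # Assumption: the number of training nodes is rounded down.\n",
    "    n_train = int(train_pct * len(shuffled))\n",
    "    return shuffled[:n_train], shuffled[n_train:], holdout_idx\n",
    "\n",
    "# Label counts taken from the value_counts() output in this notebook.\n",
    "labels = ['normal'] * 4427 + ['hateful'] * 544 + ['unknown'] * 95415\n",
    "train_idx, val_idx, holdout_idx = split_labeled_nodes(labels)\n",
    "print(len(train_idx), len(val_idx), len(holdout_idx))  # 745 4226 95415\n",
    "```\n",
    "\n",
    "With `holdout_for_inductive=False`, the 95,415 holdout nodes stay in the graph during training; only their labels are hidden."
   ]
  },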
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Largest subgraph statistics: 100386 nodes, 2194979 edges\n",
      "using 95415 nodes with missing target as holdout set\n",
      "Size of training graph: 100386 nodes\n",
      "Training nodes: 745\n",
      "Validation nodes: 4226\n",
      "Nodes treated as unlabeled for testing/inference: 95415\n",
      "Holdout node features are visible during training (transductive inference)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "(train_data, val_data, preproc, \n",
    " df_holdout, G_complete)        = gr.graph_nodes_from_csv('/tmp/twitter-nodes.tab',\n",
    "                                           '/tmp/twitter-edges.tab',\n",
    "                                           sample_size=20, \n",
    "                                           holdout_pct=None,  # using missing_label_value for holdout\n",
    "                                           holdout_for_inductive=False,\n",
    "                                           missing_label_value='unknown',\n",
    "                                           train_pct=0.15,\n",
    "                                           sep='\\t')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The training graph and the dataframe containing features of all nodes in the training graph are both accessible via the `Preprocessor` instance.  Let's look at the class distributions in the training graph. There is class imbalance here that might be addressed by computing class weights and supplying them to the `class_weight` parameter of any `*fit*` method in *ktrain* and Keras.  We will train without doing so here, though."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Initial hateful/normal users distribution\n",
      "unknown    95415\n",
      "normal      4427\n",
      "hateful      544\n",
      "Name: target, dtype: int64\n"
     ]
    }
   ],
   "source": [
    "print(\"Initial hateful/normal users distribution\")\n",
    "print(preproc.df.target.value_counts())"
   ]
  },
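  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you do want to counter the class imbalance shown above, one common heuristic (the one scikit-learn calls 'balanced') weights each class by n_samples / (n_classes * count). The sketch below computes it with NumPy from the labeled counts printed above (4,427 normal, 544 hateful). Two caveats: in practice you would compute it from the training labels only, and the integer keys must match the class indices used by `train_data`; the alphabetical order assumed here (hateful=0, normal=1) is an assumption you should verify.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def balanced_class_weights(y):\n",
    "    # Same formula as scikit-learn's 'balanced' heuristic:\n",
    "    # weight_c = n_samples / (n_classes * count_c)\n",
    "    classes, counts = np.unique(y, return_counts=True)  # classes come back sorted\n",
    "    weights = len(y) / (len(classes) * counts)\n",
    "    # Keras-style mapping {class_index: weight}; indices follow the sorted class order.\n",
    "    class_weight = {i: float(w) for i, w in enumerate(weights)}\n",
    "    return class_weight, [str(c) for c in classes]\n",
    "\n",
    "y = np.array(['normal'] * 4427 + ['hateful'] * 544)\n",
    "class_weight, class_names = balanced_class_weights(y)\n",
    "print(class_names)  # ['hateful', 'normal']\n",
    "print(class_weight)\n",
    "```\n",
    "\n",
    "The resulting dict is the kind of mapping the `class_weight` parameter of a `*fit*` method expects; the rarer hateful class receives a weight of roughly 4.6 versus roughly 0.56 for normal."
   ]
  },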
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### STEP 2: Define a Model and Wrap in Learner Object"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "graphsage: GraphSAGE:  http://arxiv.org/pdf/1607.01759.pdf\n"
     ]
    }
   ],
   "source": [
    "gr.print_node_classifiers()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Is Multi-Label? False\n",
      "done\n"
     ]
    }
   ],
   "source": [
    "learner = ktrain.get_learner(model=gr.graph_node_classifier('graphsage', train_data), \n",
    "                             train_data=train_data, \n",
    "                             val_data=val_data, \n",
    "                             batch_size=64)"
   ]
  },
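  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `graphsage` model used above follows the GraphSAGE paper listed by `print_node_classifiers()`: each node's new representation combines its own features with an aggregate of a sampled set of its neighbors, which is how the model exploits the birds-of-a-feather structure described at the start of this notebook. The NumPy sketch below shows a single GraphSAGE-style layer with a mean aggregator on a hypothetical 4-node graph with random, untrained weights. It is illustrative only and is not the implementation used by *ktrain*; `sample_size` mirrors the `sample_size=20` argument passed to `graph_nodes_from_csv` above, which we assume sets the number of neighbors sampled per node.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def graphsage_mean_layer(features, neighbors, weight, sample_size, rng):\n",
    "    # One GraphSAGE-style layer with a mean aggregator (illustrative sketch only).\n",
    "    # features:  (n_nodes, d_in) array of node attributes\n",
    "    # neighbors: dict mapping node index -> list of neighbor indices\n",
    "    # weight:    (2 * d_in, d_out) matrix applied to [own features, neighbor mean]\n",
    "    n_nodes, d_in = features.shape\n",
    "    out = np.zeros((n_nodes, weight.shape[1]))\n",
    "    for v in range(n_nodes):\n",
    "        nbrs = neighbors.get(v, [])\n",
    "        if nbrs:\n",
    "            # Sample a fixed-size neighborhood (with replacement), then average it.\n",
    "            sampled = rng.choice(nbrs, size=sample_size, replace=True)\n",
    "            agg = features[sampled].mean(axis=0)\n",
    "        else:\n",
    "            agg = np.zeros(d_in)\n",
    "        h = np.maximum(np.concatenate([features[v], agg]) @ weight, 0.0)  # ReLU\n",
    "        norm = np.linalg.norm(h)\n",
    "        out[v] = h / norm if norm > 0 else h  # L2-normalize each embedding\n",
    "    return out\n",
    "\n",
    "rng = np.random.RandomState(0)\n",
    "features = rng.randn(4, 3)  # 4 hypothetical nodes, 3 features each\n",
    "neighbors = {0: [1, 2], 1: [0], 2: [0, 3], 3: [2]}\n",
    "weight = rng.randn(6, 2)  # random, untrained weights: 2 * 3 inputs to 2 outputs\n",
    "h = graphsage_mean_layer(features, neighbors, weight, sample_size=20, rng=rng)\n",
    "print(h.shape)  # (4, 2)\n",
    "```\n",
    "\n",
    "Stacking two such layers lets information flow from neighbors of neighbors, which is why a labeled node can inform predictions for unlabeled nodes two hops away."
   ]
  },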
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### STEP 3: Estimate LR\n",
    "Because each epoch contains only a few batches (11 in the output below), a larger number of epochs is required to estimate the learning rate. We will cap it at 100 epochs here."
   ]
  },
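  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `lr_find` call below simulates training for different learning rates, and `lr_plot()` is then used to identify the maximal learning rate associated with falling loss. With 11 batches per epoch and `max_epochs=100`, such a test can cover up to 1,100 steps. The sketch below shows one common way to build such a schedule (exponentially spaced rates between two bounds) together with a simple smoothing-and-pick heuristic, run on a made-up loss curve. The bounds `1e-7` and `10`, the smoothing factor, and the rule of dividing the LR at minimum loss by 10 are assumptions for illustration, not necessarily what *ktrain* does internally.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def lr_schedule(start_lr, end_lr, num_steps):\n",
    "    # Exponentially spaced learning rates from start_lr to end_lr (inclusive).\n",
    "    return start_lr * (end_lr / start_lr) ** (np.arange(num_steps) / (num_steps - 1))\n",
    "\n",
    "def suggest_lr(lrs, losses, smoothing=0.98):\n",
    "    # Smooth the noisy per-batch losses with a bias-corrected moving average,\n",
    "    # then return one tenth of the LR at which the smoothed loss is lowest.\n",
    "    avg, smoothed = 0.0, []\n",
    "    for i, loss in enumerate(losses, start=1):\n",
    "        avg = smoothing * avg + (1 - smoothing) * loss\n",
    "        smoothed.append(avg / (1 - smoothing ** i))\n",
    "    return lrs[int(np.argmin(smoothed))] / 10.0\n",
    "\n",
    "lrs = lr_schedule(1e-7, 10.0, num_steps=11 * 100)  # 11 batches x 100 epochs\n",
    "# Made-up loss curve: flat, then falling, then exploding at large learning rates.\n",
    "losses = 0.6 - 0.4 / (1 + np.exp(-3 * (np.log10(lrs) + 3))) + np.where(lrs > 0.1, lrs, 0.0)\n",
    "print(lrs[0], lrs[-1], len(lrs))\n",
    "print('suggested lr: %.2e' % suggest_lr(lrs, losses))\n",
    "```\n",
    "\n",
    "In practice, reading the actual `lr_plot()` output is more reliable than any single automatic rule, since per-batch losses on such short epochs are noisy."
   ]
  },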
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "simulating training for different learning rates... this may take a few moments...\n",
      "Epoch 1/100\n",
      "11/11 [==============================] - 1s 129ms/step - loss: 0.6290 - acc: 0.6650\n",
      "Epoch 2/100\n",
      "11/11 [==============================] - 1s 106ms/step - loss: 0.6213 - acc: 0.6693\n",
      "Epoch 3/100\n",
      "11/11 [==============================] - 1s 106ms/step - loss: 0.6265 - acc: 0.6871\n",
      "Epoch 4/100\n",
      "11/11 [==============================] - 1s 105ms/step - loss: 0.6252 - acc: 0.6963\n",
      "Epoch 5/100\n",
      "11/11 [==============================] - 1s 108ms/step - loss: 0.6309 - acc: 0.6832\n",
      "Epoch 6/100\n",
      "11/11 [==============================] - 1s 104ms/step - loss: 0.6445 - acc: 0.6350\n",
      "Epoch 7/100\n",
      "11/11 [==============================] - 1s 109ms/step - loss: 0.6300 - acc: 0.6543\n",
      "Epoch 8/100\n",
      "11/11 [==============================] - 1s 109ms/step - loss: 0.6207 - acc: 0.6785\n",
      "Epoch 9/100\n",
      "11/11 [==============================] - 1s 98ms/step - loss: 0.6455 - acc: 0.6495\n",
      "Epoch 10/100\n",
      "11/11 [==============================] - 1s 108ms/step - loss: 0.6296 - acc: 0.6577\n",
      "Epoch 11/100\n",
      "11/11 [==============================] - 1s 91ms/step - loss: 0.6227 - acc: 0.6906\n",
      "Epoch 12/100\n",
      "11/11 [==============================] - 1s 96ms/step - loss: 0.6279 - acc: 0.6828\n",
      "Epoch 13/100\n",
      "11/11 [==============================] - 1s 106ms/step - loss: 0.6233 - acc: 0.6757\n",
      "Epoch 14/100\n",
      "11/11 [==============================] - 1s 119ms/step - loss: 0.6258 - acc: 0.6884\n",
      "Epoch 15/100\n",
      "11/11 [==============================] - 1s 111ms/step - loss: 0.6312 - acc: 0.6757\n",
      "Epoch 16/100\n",
      "11/11 [==============================] - 1s 106ms/step - loss: 0.6386 - acc: 0.6600\n",
      "Epoch 17/100\n",
      "11/11 [==============================] - 1s 104ms/step - loss: 0.6318 - acc: 0.6729\n",
      "Epoch 18/100\n",
      "11/11 [==============================] - 1s 107ms/step - loss: 0.6249 - acc: 0.6985\n",
      "Epoch 19/100\n",
      "11/11 [==============================] - 1s 110ms/step - loss: 0.6316 - acc: 0.6656\n",
      "Epoch 20/100\n",
      "11/11 [==============================] - 1s 110ms/step - loss: 0.6154 - acc: 0.7031\n",
      "Epoch 21/100\n",
      "11/11 [==============================] - 1s 101ms/step - loss: 0.6404 - acc: 0.6308\n",
      "Epoch 22/100\n",
      "11/11 [==============================] - 1s 92ms/step - loss: 0.6310 - acc: 0.6750\n",
      "Epoch 23/100\n",
      "11/11 [==============================] - 1s 93ms/step - loss: 0.6150 - acc: 0.6814\n",
      "Epoch 24/100\n",
      "11/11 [==============================] - 1s 97ms/step - loss: 0.6292 - acc: 0.6629\n",
      "Epoch 25/100\n",
      "11/11 [==============================] - 1s 114ms/step - loss: 0.6173 - acc: 0.6776\n",
      "Epoch 26/100\n",
      "11/11 [==============================] - 1s 110ms/step - loss: 0.6191 - acc: 0.6942\n",
      "Epoch 27/100\n",
      "11/11 [==============================] - 1s 112ms/step - loss: 0.6225 - acc: 0.6771\n",
      "Epoch 28/100\n",
      "11/11 [==============================] - 1s 107ms/step - loss: 0.6172 - acc: 0.6895\n",
      "Epoch 29/100\n",
      "11/11 [==============================] - 1s 104ms/step - loss: 0.6170 - acc: 0.6776\n",
      "Epoch 30/100\n",
      "11/11 [==============================] - 1s 100ms/step - loss: 0.6127 - acc: 0.6952\n",
      "Epoch 31/100\n",
      "11/11 [==============================] - 1s 110ms/step - loss: 0.6057 - acc: 0.7107\n",
      "Epoch 32/100\n",
      "11/11 [==============================] - 1s 108ms/step - loss: 0.6125 - acc: 0.7045\n",
      "Epoch 33/100\n",
      "11/11 [==============================] - 1s 101ms/step - loss: 0.5887 - acc: 0.7348\n",
      "Epoch 34/100\n",
      "11/11 [==============================] - 1s 108ms/step - loss: 0.5973 - acc: 0.7348\n",
      "Epoch 35/100\n",
      "11/11 [==============================] - 1s 108ms/step - loss: 0.5597 - acc: 0.7776\n",
      "Epoch 36/100\n",
      "11/11 [==============================] - 1s 100ms/step - loss: 0.5613 - acc: 0.7818\n",
      "Epoch 37/100\n",
      "11/11 [==============================] - 1s 109ms/step - loss: 0.5502 - acc: 0.7839\n",
      "Epoch 38/100\n",
      "11/11 [==============================] - 1s 104ms/step - loss: 0.5477 - acc: 0.7897\n",
      "Epoch 39/100\n",
      "11/11 [==============================] - 1s 110ms/step - loss: 0.5317 - acc: 0.8054 0s - loss: 0.5096 - acc: 0\n",
      "Epoch 40/100\n",
      "11/11 [==============================] - 1s 119ms/step - loss: 0.5094 - acc: 0.8511\n",
      "Epoch 41/100\n",
      "11/11 [==============================] - 1s 108ms/step - loss: 0.5038 - acc: 0.8439\n",
      "Epoch 42/100\n",
      "11/11 [==============================] - 1s 106ms/step - loss: 0.4800 - acc: 0.8622\n",
      "Epoch 43/100\n",
      "11/11 [==============================] - 1s 106ms/step - loss: 0.4655 - acc: 0.8710\n",
      "Epoch 44/100\n",
      "11/11 [==============================] - 1s 109ms/step - loss: 0.4374 - acc: 0.8845\n",
      "Epoch 45/100\n",
      "11/11 [==============================] - 1s 111ms/step - loss: 0.4223 - acc: 0.8866\n",
      "Epoch 46/100\n",
      "11/11 [==============================] - 1s 108ms/step - loss: 0.4025 - acc: 0.8824\n",
      "Epoch 47/100\n",
      "11/11 [==============================] - 1s 97ms/step - loss: 0.3779 - acc: 0.8945\n",
      "Epoch 48/100\n",
      "11/11 [==============================] - 1s 94ms/step - loss: 0.3659 - acc: 0.8888\n",
      "Epoch 49/100\n",
      "11/11 [==============================] - 1s 108ms/step - loss: 0.3548 - acc: 0.8859\n",
      "Epoch 50/100\n",
      "11/11 [==============================] - 1s 111ms/step - loss: 0.3284 - acc: 0.8845\n",
      "Epoch 51/100\n",
      "11/11 [==============================] - 1s 110ms/step - loss: 0.3067 - acc: 0.8945\n",
      "Epoch 52/100\n",
      "11/11 [==============================] - 1s 121ms/step - loss: 0.2858 - acc: 0.8977\n",
      "Epoch 53/100\n",
      "11/11 [==============================] - 1s 110ms/step - loss: 0.2695 - acc: 0.8905\n",
      "Epoch 54/100\n",
      "11/11 [==============================] - 1s 108ms/step - loss: 0.2523 - acc: 0.9077\n",
      "Epoch 55/100\n",
      "11/11 [==============================] - 1s 104ms/step - loss: 0.2539 - acc: 0.9237\n",
      "Epoch 56/100\n",
      "11/11 [==============================] - 1s 90ms/step - loss: 0.2307 - acc: 0.9185\n",
      "Epoch 57/100\n",
      "11/11 [==============================] - 1s 108ms/step - loss: 0.2081 - acc: 0.9318\n",
      "Epoch 58/100\n",
      "11/11 [==============================] - 1s 104ms/step - loss: 0.2110 - acc: 0.9237\n",
      "Epoch 59/100\n",
      "11/11 [==============================] - 1s 101ms/step - loss: 0.1850 - acc: 0.9437\n",
      "Epoch 60/100\n",
      "11/11 [==============================] - 1s 96ms/step - loss: 0.1805 - acc: 0.9373\n",
      "Epoch 61/100\n",
      "11/11 [==============================] - 1s 108ms/step - loss: 0.1932 - acc: 0.9252\n",
      "Epoch 62/100\n",
      "11/11 [==============================] - 2s 141ms/step - loss: 0.1815 - acc: 0.9347\n",
      "Epoch 63/100\n",
      "11/11 [==============================] - 1s 110ms/step - loss: 0.1673 - acc: 0.9316\n",
      "Epoch 64/100\n",
      "11/11 [==============================] - 1s 98ms/step - loss: 0.1852 - acc: 0.9407\n",
      "Epoch 65/100\n",
      "11/11 [==============================] - 1s 109ms/step - loss: 0.1670 - acc: 0.9332\n",
      "Epoch 66/100\n",
      "11/11 [==============================] - 1s 106ms/step - loss: 0.1820 - acc: 0.9373\n",
      "Epoch 67/100\n",
      "11/11 [==============================] - 1s 96ms/step - loss: 0.1355 - acc: 0.9521\n",
      "Epoch 68/100\n",
      "11/11 [==============================] - 1s 115ms/step - loss: 0.1501 - acc: 0.9517\n",
      "Epoch 69/100\n",
      "11/11 [==============================] - 1s 106ms/step - loss: 0.1522 - acc: 0.9521\n",
      "Epoch 70/100\n",
      "11/11 [==============================] - 1s 109ms/step - loss: 0.1639 - acc: 0.9460\n",
      "Epoch 71/100\n",
      "11/11 [==============================] - 1s 101ms/step - loss: 0.1746 - acc: 0.9287\n",
      "Epoch 72/100\n",
      "11/11 [==============================] - 1s 121ms/step - loss: 0.1861 - acc: 0.9294\n",
      "Epoch 73/100\n",
      "11/11 [==============================] - 1s 115ms/step - loss: 0.1656 - acc: 0.9444\n",
      "Epoch 74/100\n",
      "11/11 [==============================] - 1s 110ms/step - loss: 0.1401 - acc: 0.9516\n",
      "Epoch 75/100\n",
      "11/11 [==============================] - 1s 103ms/step - loss: 0.1422 - acc: 0.9401\n",
      "Epoch 76/100\n",
      "11/11 [==============================] - 1s 100ms/step - loss: 0.1438 - acc: 0.9537\n",
      "Epoch 77/100\n",
      "11/11 [==============================] - 1s 110ms/step - loss: 0.1594 - acc: 0.9394\n",
      "Epoch 78/100\n",
      "11/11 [==============================] - 1s 105ms/step - loss: 0.1477 - acc: 0.9380\n",
      "Epoch 79/100\n",
      "11/11 [==============================] - 1s 116ms/step - loss: 0.1931 - acc: 0.9261\n",
      "Epoch 80/100\n",
      "11/11 [==============================] - 1s 111ms/step - loss: 0.1378 - acc: 0.9587\n",
      "Epoch 81/100\n",
      "11/11 [==============================] - 1s 105ms/step - loss: 0.1332 - acc: 0.9456\n",
      "Epoch 82/100\n",
      "11/11 [==============================] - 1s 109ms/step - loss: 0.1417 - acc: 0.9560\n",
      "Epoch 83/100\n",
      "11/11 [==============================] - 1s 98ms/step - loss: 0.1865 - acc: 0.9344\n",
      "Epoch 84/100\n",
      "11/11 [==============================] - 1s 99ms/step - loss: 0.1545 - acc: 0.9416\n",
      "Epoch 85/100\n",
      "11/11 [==============================] - 1s 113ms/step - loss: 0.1665 - acc: 0.9387\n",
      "Epoch 86/100\n",
      "11/11 [==============================] - 1s 102ms/step - loss: 0.1233 - acc: 0.9558\n",
      "Epoch 87/100\n",
      "11/11 [==============================] - 1s 109ms/step - loss: 0.2232 - acc: 0.9095\n",
      "Epoch 88/100\n",
      "11/11 [==============================] - 1s 99ms/step - loss: 0.1615 - acc: 0.9437\n",
      "Epoch 89/100\n",
      "11/11 [==============================] - 1s 103ms/step - loss: 0.1802 - acc: 0.9373\n",
      "Epoch 90/100\n",
      "11/11 [==============================] - 1s 109ms/step - loss: 0.1947 - acc: 0.9460\n",
      "Epoch 91/100\n",
      "11/11 [==============================] - 1s 104ms/step - loss: 0.2548 - acc: 0.8948\n",
      "Epoch 92/100\n",
      "11/11 [==============================] - 1s 104ms/step - loss: 0.1870 - acc: 0.9347\n",
      "Epoch 93/100\n",
      "11/11 [==============================] - 1s 98ms/step - loss: 0.2973 - acc: 0.9002\n",
      "Epoch 94/100\n",
      "11/11 [==============================] - 1s 102ms/step - loss: 0.2487 - acc: 0.9245\n",
      "Epoch 95/100\n",
      "11/11 [==============================] - 1s 103ms/step - loss: 0.1924 - acc: 0.9294\n",
      "Epoch 96/100\n",
      "11/11 [==============================] - 1s 101ms/step - loss: 0.3887 - acc: 0.9009\n",
      "Epoch 97/100\n",
      "11/11 [==============================] - 1s 109ms/step - loss: 0.3643 - acc: 0.9045\n",
      "Epoch 98/100\n",
      "11/11 [==============================] - 1s 100ms/step - loss: 0.3606 - acc: 0.9159\n",
      "Epoch 99/100\n",
      "11/11 [==============================] - 1s 98ms/step - loss: 0.4588 - acc: 0.8867\n",
      "Epoch 100/100\n",
      "11/11 [==============================] - 1s 99ms/step - loss: 0.5343 - acc: 0.9173\n",
      "\n",
      "\n",
      "done.\n",
      "Please invoke the Learner.lr_plot() method to visually inspect the loss plot to help identify the maximal learning rate associated with falling loss.\n"
     ]
    }
   ],
   "source": [
    "learner.lr_find(max_epochs=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEKCAYAAAAfGVI8AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3deXhU1f3H8fc3k40sBEJC2AJhCbIJogFRtChYirWirbtd3Kr117rULtZuLq3WtmoXK9q617bWvRarFZWCCggSkH0NYV9D2LMv5/fHDDFgAgnk5s5kPq/nmYd775yZ+ebWzmfuPfeeY845REQkesX4XYCIiPhLQSAiEuUUBCIiUU5BICIS5RQEIiJRTkEgIhLlYv0uoLkyMjJcTk6O32WIiESUefPm7XTOZTb0XMQFQU5ODvn5+X6XISISUcxsfWPP6dSQiEiUUxCIiEQ5BYGISJRTEIiIRDkFgYhIlFMQiIhEOQVBE23fV86SzXuP2q621rFk815KKqpboSoRkeMXcfcR+ME5x9efmsOq7Qfok5lMRkoC9104hNys1EPaPDBlJY9OXwNAz/Qkrj+zN2MHZtG9Q7vj+vzyqhrKKmvokBSHc2AGZnZc7ykicpBF2sQ0eXl57lhuKFu1fT8zVu/k2jN6N+t1zjn+PmcDP399yWee65yawMje6SzevJf1xaV124f37MDesioKi0rqtiXHB7hiZE8S4wK8t3w7VTW1XH9mHy44qTvt4gOfee8V2/ZRWFTCtr3lPDKtgF0llZ9pM35QFuuKSyipqGFYdhobd5WRmZrAPRMHk52eBMDHa3cxdcV29pVVM2XpNuICxqV52dx2Tn9iYhQmItHCzOY55/IafC5aguDJDwu5983lfHj72WSnJ7GntJK0dnHsK6+mttZhBvPW72bsgM6YGVv2lPHdFxcwf/1uqmsdPdOTePu7Z5IYG2DWmmLunLyk7ou+Y1Icl+Zl075dHNef2Yf42Bhqah2fbNjN3HW7mbtuF/9bsaPBulITYxndN4OZBTsBKK+uIcaMiuraujYDuqSy80Al5VU1jO7XifTkBF7/ZDMV1TXkZCRTWFRCUnyADu3i2LqvnJSEWG4c05eyyhoemVZQ9z5nnZDJh6t3UlPr6NI+kQuHd2dXSQVb95ZjZkwc1o2uaYlkd0wiELDjPpIRkfChIAAKiw4w9qH3geAv9k827Gm07ZeGdiU+EMNrn2wGYEz/TB796skkJ3x6Js05x96yKsqramkXHyCtXdwRP/9g+wMV1XTv0I6aWsdTM9by2Ptr2FNaRUpCLM45SipraJ8YS15OOl8YnMWALu0Z2iOtwVNBzjnMjPKqGpyDdvEB1u4s4frn8inYcQCA9omxPHfdqfTPSiEpPrbuCOfZmWtZEwqy9OT4zxxxBGKM0/p04qwTMrnujN5U1zo27S4ju2M7YgPqWhKJNAqCkHveWMozM9cdsi0uYFTVOLp3aEetc2zdW173XNe0RN64+QwyUhKOp+Sjqq11xMQYzjlWbNtP/6xUAsdx2qa0sprZhcV0ad+O/lkpDX5xO+d4b/kO4gLGWSd0Zk9pJTMLitlVWslbi7bSvl0sU5Zur2ufmhjL/vJgB3hqQizl1TWkJsbxhcFdyExNYFSfdAZ2aU/H5PhjrltEvKMgCHHOUVXjeHHuBvpmppCblUpmasIhz+8ureIXbyylXXyAn543iJSE6O1PL6+q4bmP1vHh6p2s2XGAId3TqA0dhcxbv/szRxEpCbFcOzqHvp1TGNM/kw5JCgWRcKEgEM+UVdYweeFmNuwq5dmZ6yiprAEgO70dj331FMxgUNf2FO2vIDM1QVc7ifhEQSCtYsf+cnbsq2DZ1n386q3l7CmtOuT5cQM688crhkf1UZaIX44UBPp/pLSYzqmJdE5NZEj3NIZ0S+OpGWtZtGkPndsnkBgbYPqqIk67fyo3nNmHm8fl
+l2uiIToiEBazbz1u/j9u6uZUbCTnE5JdElL5Ilv5JGaeOQrrkTk+B3piEDXAUqrOaVXOs9eM4Lzh3VjXXEpswt3MfSed/jDe6v8Lk0kqnkaBGY2wcxWmlmBmd3RSJtLzWyZmS01s+e9rEf8FxuI4U9XDGfx3eP51ZdP5OwTOvOH91YzfWXDN9yJiPc8CwIzCwCTgHOBQcAVZjbosDa5wI+B0c65wcB3vapHwktqYhxXntqTR796Mr0zkvnBy4vYsb/86C8UkRbn5RHBSKDAOVfonKsEXgAuOKzN9cAk59xuAOecfhZGmcS4AA9fPpzSymoufGQma4oO+F2SSNTxMgi6AxvrrW8KbauvP9DfzGaa2Wwzm9DQG5nZDWaWb2b5RUVFHpUrfjmxRxr/+Oap7K+o5qf/WkxVTe3RXyQiLcbvzuJYIBc4C7gCeMLMOhzeyDn3uHMuzzmXl5mZ2colSmsY3rMjP/3iQGYX7uLaZ+dSWxtZV7OJRDIvg2AzkF1vvUdoW32bgMnOuSrn3FpgFcFgkCh0+ciejBvQmQ9X72TAnW+zfZ/6DERag5dBMBfINbPeZhYPXA5MPqzN6wSPBjCzDIKnigo9rEnC3CNXnsyEwV2orK7l6mfmKgxEWoFnQeCcqwZuAqYAy4GXnHNLzewXZjYx1GwKUGxmy4BpwA+dc8Ve1SThr118gD9//RSeuWYEa3ce4MonZrN1b5nfZYm0abqzWMLWjNU7+cbTc0iKj2Xq98eQ1T7R75JEIpbuLJaIdEZuBs9deyrlVTX84OWFuppIxCMKAglrZ+RmcN+Xh/Dh6p385LXF1OhqIpEWp9FHJexdNqInG3aVMmnaGjbtLuOJq/I0lLVIC9IRgUSEH4w/gatPz2H22mIufmwWpZXVfpck0mYoCCQimBl3TxzMU1flsWLbfv743mq/SxJpMxQEElHGDsjiklN68MzMdWzbq3sMRFqCgkAizi3jcqlxjrsnL1UYiLQABYFEnOz0JK4c2ZO3l25jzAPTeDl/49FfJCKNUhBIRLrr/EE8/81TGdajA3dNXqq5DESOg4JAIlJsIIbT+2Xwm4uHUlFdy6/fWkGk3SUvEi4UBBLRemck8/VRvXjtk828sWir3+WIRCQFgUS8n39pEEN7pPHL/yyjpEL3F4g0l4JAIl4gxrjr/EEUH6jg568v8bsckYijIJA24ZRe6dx0dj9e+2Qzf3l/jd/liEQUBYG0Gd8+ux9n5mZw/39XMH/Dbr/LEYkYCgJpMxLjAjz61ZNJjg/wzMx1fpcjEjEUBNKmpCbGcc3o3ryxcAvf/Gu+hq0WaQIFgbQ5t4zL5drRvXlv+XYmL9zsdzkiYU9BIG1OfGwMPztvIAO6pPKn/xXoqEDkKBQE0ibFxBi3jsulsKiE5+es97sckbCmIJA2a8KQLpzWpxN/nKqjApEjURBIm2VmfG1UL3YeqOCVeRqhVKQxCgJp08YN7EzP9CQe1lGBSKMUBNKmJcYFuOPcAWzeU8YHq4r8LkckLCkIpM37/KAsMlMT+Om/FrNw4x6/yxEJOwoCafPiAjHcM3Ewu0ur+NGri6iqqfW7JJGwoiCQqPDFE7vym4uHsmLbfiYv2OJ3OSJhRUEgUeP8oV3J7ZzC32brvgKR+hQEEjXMjAuHd2fBxj0s2qS+ApGDFAQSVc4d0gWA219Z5HMlIuFDQSBRpU9mCrdPOIEV2/azobjU73JEwoKCQKLOuUO6AmhkUpEQT4PAzCaY2UozKzCzOxp4/mozKzKzBaHHN72sRwSgd0YyZ+Zm8OysdWzfV+53OSK+8ywIzCwATALOBQYBV5jZoAaavuicOyn0eNKrekTqu2VcLjsPVHLXv5f6XYqI77w8IhgJFDjnCp1zlcALwAUefp5Ik43ISefq03P438od7C2r8rscEV95GQTdgfpDPm4KbTvcRWa2yMxeMbPsht7IzG4ws3wzyy8q0ngx0jIuHN6dyupa3lioG8wkuvndWfwG
kOOcGwq8C/y1oUbOucedc3nOubzMzMxWLVDarmE90hiW3YGH3lnJ5j1lfpcj4hsvg2AzUP8Xfo/QtjrOuWLnXEVo9UngFA/rETmEmXHfhUM4UFHN0zPW+l2OiG+8DIK5QK6Z9TazeOByYHL9BmbWtd7qRGC5h/WIfMaQ7mkM79mRp2asZdNu3Vcg0cmzIHDOVQM3AVMIfsG/5Jxbama/MLOJoWa3mNlSM1sI3AJc7VU9Io353uf7A/DPjzf4XImIP8y5yJq1KS8vz+Xn5/tdhrQx33j6Y9buPMAHPzwbM/O7HJEWZ2bznHN5DT3nd2exSFiYOKwbG3eV8YkmrpEopCAQAb4wOIv42BjNVSBRSUEgAqQmxvG53AymrdzhdykirU5BIBIysnc664tLef0TDUYn0UVBIBJy8SnB217u/PcSivZXHKW1SNuhIBAJSU+O57Vvn87+imr+/P4av8sRaTUKApF6Tu7ZkUtO6cFfZ61ji4adkCihIBA5zHfO7keMGQ++s9LvUkRahYJA5DC9OiVzSV4P/rNoK8UH1FcgbZ+CQKQB14zOobK6lhfmbjx6Y5EIpyAQaUC/zqmckJXK7MJiv0sR8ZyCQKQRo/qkk79uN+VVNX6XIuIpBYFII8YOzKKsqoYHp6jTWNo2BYFII07v2wmAJ2esZac6jaUNUxCINCIuEMPLN55GfCCGn7++xO9yRDyjIBA5ghE56Vx5ak+mLN3GrIKdfpcj4gkFgchR3HZOfzqnJvL4h4V+lyLiCQWByFGkJcUxdmBn5q/fTaTN6CfSFAoCkSYY2j2NfeXVLN+63+9SRFqcgkCkCc4ZlEVSfIBJ0wv8LkWkxSkIRJogIyWBS/Oy+e/irSzapHmNpW1REIg00ddG9SIlIZZrnplLZXWt3+WItBgFgUgT9eucwr1fPpHikkpWbVdfgbQdCgKRZhie3QGAmbqnQNoQBYFIM2SnJzEipyOPf1BIdY1OD0nboCAQaaZrRvemuKSS/PW7/S5FpEUoCESaaUz/TOJjY7jz30t0g5m0CQoCkWZKTojlC4O7sGr7AT5ao4lrJPIpCESOwW8vGkp8IIY3Fm3xuxSR46YgEDkG7eIDXDEymxfmbuSvs9b5XY7IcVEQiByj2ycMoFtaO+6avJTt+8r9LkfkmCkIRI5RckIsz1wzAoB3lm7zuRqRY9ekIDCzW82svQU9ZWbzzWx8E143wcxWmlmBmd1xhHYXmZkzs7zmFC/it9zOKfTJTOa/SxQEErmaekRwrXNuHzAe6Ah8Hfj1kV5gZgFgEnAuMAi4wswGNdAuFbgVmNOMukXCgplx7pAuzFm7iz2llX6XI3JMmhoEFvr3i8DfnHNL621rzEigwDlX6JyrBF4ALmig3S+B3wA6ySoR6XO5mdTUOuau0w1mEpmaGgTzzOwdgkEwJfQr/mj313cHNtZb3xTaVsfMTgaynXNvHumNzOwGM8s3s/yioqImlizSOoZldyA+EMPHa3VPgUSm2Ca2uw44CSh0zpWaWTpwzfF8sJnFAL8Drj5aW+fc48DjAHl5ebqVU8JKYlyAk3p2YJZuLpMI1dQjgtOAlc65PWb2NeBnwN6jvGYzkF1vvUdo20GpwBBgupmtA0YBk9VhLJFoTP9Mlm7Zx/riEr9LEWm2pgbBY0CpmQ0Dvg+sAZ47ymvmArlm1tvM4oHLgckHn3TO7XXOZTjncpxzOcBsYKJzLr+5f4SI3847sStxAWPSNE1lKZGnqUFQ7YKja10APOKcm0TwF32jnHPVwE3AFGA58JJzbqmZ/cLMJh5P0SLhJicjmfOHdePleZtYuU2T1khkaWoQ7DezHxO8bPTN0Pn9uKO9yDn3lnOuv3Our3PuvtC2O51zkxtoe5aOBiSS/fy8QcTGGP/8eIPfpYg0S1OD4DKgguD9BNsInu9/wLOqRCJQx+R4Jgzpysv5G6mt1TUNEjmaFAShL/9/AGlm9iWg3Dl3tD4CkahzRr9OlFTW8MSH
hX6XItJkTR1i4lLgY+AS4FJgjpld7GVhIpHoS0O7AfD+Kt3vIpGjqaeGfgqMcM5d5Zz7BsG7hn/uXVkikSk5IZYbx/Rl1ppinpqx1u9yRJqkqUEQ45zbUW+9uBmvFYkqN47pQ2yM8YI6jSVCNPXL/G0zm2JmV5vZ1cCbwFvelSUSuTokxfPjLw5k9Y4DbCgu9bsckaNqamfxDwkO8TA09HjcOfcjLwsTiWTjBnQGUKexRISmjjWEc+5V4FUPaxFpM3IykjmjXwbTV+04emMRnx3xiMDM9pvZvgYe+81sX2sVKRKJPj8oi427yvhIg9FJmDtiEDjnUp1z7Rt4pDrn2rdWkSKR6LIR2STHB/j3gs1HbyziI135I+KRxLgAE4Z05bVPNlO0v8LvckQapSAQ8dA1o3OorK7lveXb/S5FpFEKAhEPDe7Wnl6dkvjxa4vZVaI5jSU8KQhEPGRmXJoXnJ/p4amrfa5GIlV5VQ2zCnZ69v4KAhGP/d+Yvpx1QibTVupSUjk2T3xQyJVPzvFsDCsFgYjHYmKM0/t2Yn1xqTqN5Zhs2VsOwL6yKk/eX0Eg0gpO6ZUOoE5jOSZ7SivpmxmcBc8LCgKRVnBSdgcGdEnlbx+t97sUiUAHKqpJTTzqpJDHTEEg0goCMcYledks27qP2YW601iap6SimpSEJo8I1GwKApFWcvEpPeiTkcz//X0e5VU1fpcjEeRARTXJCQHP3l9BINJK0trF8aNzB7C7tIolm/f6XY5EkJKKGpJ1RCDSNuT16kggxnhz8Va/S5EIsr+8SqeGRNqKTikJXDCsG/+Ys4H95d5cCihty57SSvaVV9OjYzvPPkNBINLKvjqqF5XVtUyatsbvUiQCrAvNctcnI8Wzz1AQiLSyk3t2AODP76/RVJZyVJt3lwGQkZrg2WcoCERamZnx9NV5ALyrG8zkKL7z/HwA0pPiPfsMBYGID8YOyGJg1/Y8PWOthp2QJumQrBvKRNqcX315CNv2lfP0zLV+lyIRIFVXDYm0PcN7dmRUn3Qem76GwqIDfpcjYSo9OZ6vj+qFmXn2GQoCER9dd0ZvAJ74sNDnSiRclVRUk+ThXcWgIBDx1dgBWZwzsDMfrdH4Q/JZH64uoqK6luR4704LgYJAxHen981gXXEpz6qvQOpZuHEPX3/qYwBP7yoGj4PAzCaY2UozKzCzOxp4/kYzW2xmC8xshpkN8rIekXB05ak9yemUxGufbPa7FAkTzjlue2lB3bqX9xCAh0FgZgFgEnAuMAi4ooEv+uedcyc6504Cfgv8zqt6RMJVYlyAc0/syrIt+9itCe4F2FdWTWFRSd16SgT3EYwECpxzhc65SuAF4IL6DZxz++qtJgPOw3pEwtYFJ3Wj1jl+9Ooiamv1f4NoV1wSvLfkZ+cN5Afj+zOmf2dPP8/LIOgObKy3vim07RBm9h0zW0PwiOCWht7IzG4ws3wzyy8q8mbyZhE/DejSnm+N6cs7y7YzSx3HUW93afDIMDcrlZvG5hKI8e7SUQiDzmLn3CTnXF/gR8DPGmnzuHMuzzmXl5mZ2boFirSSm8f2I6t9An9+X4PRRbviA8Eg8HJYifq8DILNQHa99R6hbY15AbjQw3pEwlpSfCwXntSdOWuLqajWDGbR7OARQXpK5AfBXCDXzHqbWTxwOTC5fgMzy623eh6w2sN6RMLe8J4dqapxTF2+w+9SxEcHx5+K+CMC51w1cBMwBVgOvOScW2pmvzCziaFmN5nZUjNbAHwPuMqrekQiwbiBncntnMJfdHooqj34zioA2sV7e7XQQZ7epeCcewt467Btd9ZbvtXLzxeJNHGBGMYPzmLStDXMLNjJ6H4ZfpckrezgjYWxHncQ1+d7Z7GIHOqyvJ4APPTOSl1KGmWcc9z9xjIAnr56RKt9roJAJMz07JTEnV8axPwNe3juo3V+lyOtqKK6tm65S1piq32ugkAkDF0zOofB3drzUv4mnNNRQbTYX15dt9w7
I7nVPldBIBKGzIwrT+3Jsq37mL5KN1FGi39+vAGA3182jLhA6309KwhEwtQlp2TTLS2Ru/69lL2lVX6XIx4qPlDB915cwO/eDV4tFOPhJDQNURCIhKn42Bi+P/4ENuwq1XSWbVhtrWPU/VMPGX22tc8GKghEwthFp/TgxO5pPPlhITv2l/tdjnjgnWXbqar59Jv/6tNzOH9Yt1atQUEgEuZ+eeEQSiprmLxgi9+liAcKduyvW/7bdSO5e+JgzweZO5yCQCTMDeuRRnpyPPe+uZziAxV+lyMtbMvectKT41nxywmcmevPoJoKApEwZ2b8YPwJAMwo2OlzNdLSFm3aQ/+sFBLjWmc4iYYoCEQiwGUjsuneoR0vfLzx6I0lYtTWOlZu28+w7A6+1qEgEIkAgRjja6N68VFhMW8t3up3OdJCduyvoKrGkd0xydc6FAQiEeLyEdkEYowfv7aY8irNV9AWPD9nPQC9OikIRKQJOibHc8/Ewewtq+KHryzyuxxpAQ//rwBAp4ZEpOnOHdIFgDcWbmHBxj0+VyPHKyE2hlF90mmfGOdrHQoCkQjSKSWBaT84C4CX8tVxHMnW7iyhorqWcwZm+V2KgkAk0vTOSGbsgM7MXlPsdylyHB6cshKAUX06+VyJgkAkIp3WpxOFO0vYtlfDTkSqZVv30SczmSHd0/wuRUEgEonOyA1OYXn6r6eyobjU52qkuXaXVLJ2ZwmXnJLtdymAgkAkIg3s2p5vjelDrYN731ymyWsizIZdwfDu1znF50qCFAQiEerH5w7k5rH9eGfZdt5YpJvMIsmesuD8EunJ/l4tdJCCQCSC3Toul8zUBG755ye6iiiC7CoJDh6Y1i7e50qCFAQiESw2EMPvLz0JgNtfWcQbCzVUdST4yWtLAOiQpCMCEWkBZ+Rm8OevnQzA919eyIpt+3yuSI4mLhCcb6BTso4IRKSFTBjSlVl3jKWqppbHPyikolpjEYWbd5dtZ/6G3QB069CO8YOysFaem7gxCgKRNqJbh3YM7dGB1+Zv5oSfvc364hK/S5KQxZv2cv1z+Xzl0VnU1DpKKqtJSYj1u6w6CgKRNuQn5w6oW77osVns2KcbzvxUWV3LtBU7eOjdlXXb3l+1g9KKGpIS/JuI5nAKApE25NQ+nZj703M4vW8ndh6oZOSvpvKbt1f4XVbUmjStgGuencv0lUV1216cu5GSymqS43VEICIeyUxN4K/XjuTJb+Qxul8nnp6xVvMX+OS1TzbVLf/u0mFcfXoOU5Zup7yqlrQwuWIIFAQibVJcIIZzBmXxnbP7UVFdy9kPTtfdx61sTmExG3eVAfDmLWfwlZN78KWhXeueP39oN79K+wwFgUgbdnrfDLq0T2Tr3nJ6//gtbnp+vgKhFfzkX4u57PHZdeuDuwUHlhuW3YHYGKN3RjLZ6f7OSlafgkCkjfvXd06vW/7Poq0s3aL7DLzknOP5ORsAyEiJZ+YdY+ueiwvEUPCrL9bNKREuPA0CM5tgZivNrMDM7mjg+e+Z2TIzW2RmU82sl5f1iESjrmntuGVsP84Z2BmAf32y2eeK2rY1RZ9etvuvb4+me4d2PlbTNJ4FgZkFgEnAucAg4AozG3RYs0+APOfcUOAV4Lde1SMSzb43/gSevGoE5wzM4qkZa7nmmY9105kH9pZVcc7v3gdgxo/ODqvTP0fi5RHBSKDAOVfonKsEXgAuqN/AOTfNOXdwMPXZQA8P6xGJereOy6VDUhzTVhbxh/dW+11Om/P32esBOKVXR3p0jIwQAG+DoDtQfzjETaFtjbkO+G9DT5jZDWaWb2b5RUVFDTURkSY4sUcaC+4cz3knduWx6WsYdOfbXDhpJgU79vtdWptwcJKgJ7+R53MlzRMWncVm9jUgD3igoeedc4875/Kcc3mZmZmtW5xIG/Tri04EoLSyhgUb93DFE3PYuEsznR0L51zdlVjFJZUM6JJKxzAZTK6pvAyCzUD9edh6hLYdwszO
AX4KTHTOVXhYj4iEpCbGkRuaHeu6M3pTtL+CM387jeufy/e5ssjzpT/N4Of/XkJldS3vLd9OVU2t3yU1m5f3OM8Fcs2sN8EAuBy4sn4DMxsO/AWY4Jzb4WEtInKYF791GvvKqsjJSCY2YPzl/ULeXbadd5dt5/ODsvwuLyKs2LaPpVuCj5fmBu8i7pSc4HNVzefZEYFzrhq4CZgCLAdecs4tNbNfmNnEULMHgBTgZTNbYGaTvapHRA6VnhxPTkYyEJz28t/fGQ3A9c/lU3xAB+dN8cU/fli3XBk6EnjwkmF+lXPMPO0jcM695Zzr75zr65y7L7TtTufc5NDyOc65LOfcSaHHxCO/o4h4ZVh2B565ZgQAj39Q6HM14a+ssoba0E3az19/at32Hh3D/76Bw4XP8Hci4ruz+mcysnc6f/mgkOSEWG4e2y9sJk8JN3+bvQ6A564dyel9M7hiZE/KKquJiYm8/WWRNu5IXl6ey89Xh5aIVzbuKuXM306rW//j5SdxwUlHuvI7+kyaVsADU1Yysnc6z3/zVGIDYXEB5hGZ2TznXIPXtYZ/9SLSqrLTk5jzk3F163e8urjuEslI++HohcrqWh6YEpxo5sGLh0VECByNTg2JyGdktU9k+S8m8Nj7a3h46mrGPDCdDaH7DGbdMZZuRxg/Z8ueMiYv3EJZZQ23fb5/a5XcarbuDQ4t/ZXh3enZKXLuHj4SBYGINKhdfIBbxvbj4amr60IA4L63ljOqdzq9M1Lo2zmZrmnBUNhXXkV5ZQ3f+ts8Fm/eCwQ7Ts/IzaBzaiKBCDx3DrC7pJKK6lq6pCUC8M+PgwMmXJzXdkbEUR+BiBzR20u2cuPf5wOQFB+gtPLTweo6JsUx5bbP8eVJs9i8p6zR9/jK8O787rKTPK+1pe3YV86Zv52Gc8ErgzomxzPuoeCgcivvnUBCbPjMO3w0R+ojUBCISJPNKtjJlU/OOWKb+EAM799+FsUHKrnvzeV8VFgMwI1j+nLHuQNao8xjkr9uF49MK+D2Lwygd0Yy7eID/ObtFTw2fU2D7df9+rxWrvD4KAhEpMWUV9WQEBvDntIqhv/y3brtd58/iMtH9iQx7tBfyau37+fzv/+gbr1rWiLvfW8MyQnNP3t0VcMAAArYSURBVDO9aXcpWe0TifOgg7b/z/5LZXXwprBTe6czZ+0uAM7MzSApPsCUpdvr2r70rdMY2Tu9xWvw0pGCQH0EItIsB7/oOybHs+7X53Ggopp2cYFG+wD6dU5h3IDOTF0RHEVm695yBt81hTH9M7lweDcmDuvepP6Dl/I3cvsriwB473ufo1/n1Bb6i4KngA6GAFAXAgA3j81lZO90Ji/cQq/0JAZ3a98mrhSqr239NSLS6lISYo/4RW5mPHX1CM4b2pXT+nRiYNf2ALy/qojbXlzIhD98UDd880HVNbXMWL2TfeVVOOd48sPCuhAAOOd3H7Bw454m17h4016mLv/0F/3esqpDnp+/Ifher337dL5zdt+67bdPOIEROR0BmDisW3DO4TYWAqBTQyLSynaVVLJ6+34enb6G91d9Or/I/V85kTNzM5hduIsfvLwQgCtG9iQzNYGHpwYn0fnrtSOZvnIHz8xcBwRvdsvplMyw7A6f+Zyte8v458cbiQ8YD76zqm57/6wUVm0/wG8vGsro3AyS4wM8Nn0Nz8xcx+J7xgOwcONeTuyeRrv4yOkMPhr1EYhI2HronZX86X8FR2331FV5jBuYRU2tY+ry7dzwt3l1z33ww7Prrukvr6rhh68s4o2FW5r0+R2T4thdWsUJWalMue1zx/ZHRAD1EYhI2Lp5bC5b9pTz6vxNddv6Z6Xwx8uHc25odM8/f+1kxg0MDo0diDHGD+7CreNyeWrGWg5UVPO5B6aRnhzP8OwOOOB/Kz4d1T42xnj++lHsKa1k3obdvLtsO+MHdeHpGWuprKlld2nwNFGkzC/sBR0RiEhYKKusobyqhveWb+fLw7sT
G4hhT2klKQmxRzwvP23FDq55du4h23I6JfHVU3uRnBDLV07u/pkrmQBKKqrZX17NqPunAjD/558nPcJmFmsOnRoSkTbt3ws289xH69m6p4z95dU8cMkwJgzp0qTXLti4h24dEumcmuhxlf7SqSERadMuOKn7MY+QelIDHc3Rpu1dByUiIs2iIBARiXIKAhGRKKcgEBGJcgoCEZEopyAQEYlyCgIRkSinIBARiXIRd2exmRUB6/2uo4VkADv9LiLCaJ8dG+235mtr+6yXcy6zoSciLgjaEjPLb+yWb2mY9tmx0X5rvmjaZzo1JCIS5RQEIiJRTkHgr8f9LiACaZ8dG+235ouafaY+AhGRKKcjAhGRKKcgEBGJcgoCEZEopyAIU2YWY2b3mdmfzOwqv+uJFGaWbGb5ZvYlv2uJBGZ2oZk9YWYvmtl4v+sJV6H/rv4a2ldf9buelqYg8ICZPW1mO8xsyWHbJ5jZSjMrMLM7jvI2FwA9gCpgk1e1hosW2mcAPwJe8qbK8NIS+8w597pz7nrgRuAyL+sNN83cf18BXgntq4mtXqzHdNWQB8zsc8AB4Dnn3JDQtgCwCvg8wS/2ucAVQAC4/7C3uDb02O2c+4uZveKcu7i16vdDC+2zYUAnIBHY6Zz7T+tU74+W2GfOuR2h1z0E/MM5N7+VyvddM/ffBcB/nXMLzOx559yVPpXtCU1e7wHn3AdmlnPY5pFAgXOuEMDMXgAucM7dD3zmNIaZbQIqQ6s13lUbHlpon50FJAODgDIze8s5V+tl3X5qoX1mwK8JfslFTQhA8/YfwVDoASygDZ5JURC0nu7Axnrrm4BTj9D+NeBPZnYm8IGXhYWxZu0z59xPAczsaoJHBG02BI6guf+d3QycA6SZWT/n3J+9LC4CNLb/HgYeMbPzgDf8KMxLCoIw5ZwrBa7zu45I5Jx71u8aIoVz7mGCX3JyBM65EuAav+vwSps7xAljm4Hseus9Qtukcdpnzad9dnyicv8pCFrPXCDXzHqbWTxwOTDZ55rCnfZZ82mfHZ+o3H8KAg+Y2T+Bj4ATzGyTmV3nnKsGbgKmAMuBl5xzS/2sM5xonzWf9tnx0f77lC4fFRGJcjoiEBGJcgoCEZEopyAQEYlyCgIRkSinIBARiXIKAhGRKKcgEM+Z2YFW+IyJTRymuiU/8ywzO/0YXjfczJ4KLV9tZo+0fHXNZ2Y5hw/J3ECbTDN7u7VqktahIJCIERoiuEHOucnOuV978JlHGo/rLKDZQQD8hAgd38c5VwRsNbPRftciLUdBIK3KzH5oZnPNbJGZ3VNv++tmNs/MlprZDfW2HzCzh8xsIXCama0zs3vMbL6ZLTazAaF2db+szexZM3vYzGaZWaGZXRzaHmNmj5rZCjN718zeOvjcYTVON7M/mFk+cKuZnW9mc8zsEzN7z8yyQsMX3wjcZmYLzOzM0K/lV0N/39yGvizNLBUY6pxb2MBzOWb2v9C+mWpmPUPb+5rZ7NDfe29DR1gWnEHrTTNbaGZLzOyy0PYRof2w0Mw+NrPU0Od8GNqH8xs6qjGzgJk9UO9/q2/Ve/p1oM3N0hXVnHN66OHpAzgQ+nc88DhgBH+E/Af4XOi59NC/7YAlQKfQugMurfde64CbQ8vfBp4MLV8NPBJafhZ4OfQZgwiOLw9wMfBWaHsXYDdwcQP1TgcerbfekU/vwv8m8FBo+W7gB/XaPQ+cEVruCSxv4L3PBl6tt16/7jeAq0LL1wKvh5b/A1wRWr7x4P487H0vAp6ot54GxAOFwIjQtvYERxxOAhJD23KB/NByDrAktHwD8LPQcgKQD/QOrXcHFvv935UeLffQMNTSmsaHHp+E1lMIfhF9ANxiZl8Obc8ObS8mOCnPq4e9z2uhf+cRnEKwIa+74HwEy8wsK7TtDODl0PZtZjbtCLW+WG+5B/CimXUl+OW6tpHXnAMMCs71AkB7M0txztX/Bd8VKGrk9afV+3v+Bvy23vYLQ8vP
Aw828NrFwENm9hvgP865D83sRGCrc24ugHNuHwSPHgiOrX8Swf3bv4H3Gw8MrXfElEbwf5O1wA6gWyN/g0QgBYG0JgPud8795ZCNwZnFzgFOc86Vmtl0gtNNApQ75w6foa0i9G8Njf83XFFv2RppcyQl9Zb/BPzOOTc5VOvdjbwmBhjlnCs/wvuW8enf1mKcc6vM7GTgi8C9ZjYV+FcjzW8DthOc2jMGaKheI3jkNaWB5xIJ/h3SRqiPQFrTFOBaM0sBMLPuZtaZ4K/N3aEQGACM8ujzZwIXhfoKsgh29jZFGp+OSX9Vve37gdR66+8QnPELgNAv7sMtB/o18jmzCA57DMFz8B+GlmcTPPVDvecPYWbdgFLn3N+BB4CTgZVAVzMbEWqTGur8TiN4pFALfJ3gfMaHmwL8n5nFhV7bP3QkAcEjiCNeXSSRRUEgrcY59w7BUxsfmdli4BWCX6RvA7Fmtpzg/LmzPSrhVYJTDy4D/g7MB/Y24XV3Ay+b2TxgZ73tbwBfPthZDNwC5IU6V5cRPJ9/COfcCoLTQqYe/hzBELnGzBYR/IK+NbT9u8D3Qtv7NVLzicDHZrYAuAu41zlXCVxGcMrThcC7BH/NPwpcFdo2gEOPfg56kuB+mh+6pPQvfHr0dTbwZgOvkQilYaglqhw8Z29mnYCPgdHOuW2tXMNtwH7n3JNNbJ8ElDnnnJldTrDj+AJPizxyPR8AFzjndvtVg7Qs9RFItPmPmXUg2On7y9YOgZDHgEua0f4Ugp27BuwheEWRL8wsk2B/iUKgDdERgYhIlFMfgYhIlFMQiIhEOQWBiEiUUxCIiEQ5BYGISJRTEIiIRLn/BzSl/aKAfFBDAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "learner.lr_plot()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### STEP 4: Train the Model\n",
    "We will train the model for 30 epochs using `autofit`, which applies a triangular learning rate policy with a maximum learning rate of 0.005. Passing `checkpoint_folder` saves the weights after each epoch, so we can reload the best-performing weights once training completes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "begin training using triangular learning rate policy with max lr of 0.005...\n",
      "Epoch 1/30\n",
      "12/12 [==============================] - 9s 748ms/step - loss: 0.6526 - acc: 0.5750 - val_loss: 0.4119 - val_acc: 0.8907\n",
      "Epoch 2/30\n",
      "12/12 [==============================] - 7s 578ms/step - loss: 0.4043 - acc: 0.8889 - val_loss: 0.3422 - val_acc: 0.8907\n",
      "Epoch 3/30\n",
      "12/12 [==============================] - 8s 674ms/step - loss: 0.3405 - acc: 0.8889 - val_loss: 0.3063 - val_acc: 0.8907\n",
      "Epoch 4/30\n",
      "12/12 [==============================] - 8s 644ms/step - loss: 0.3079 - acc: 0.8889 - val_loss: 0.2740 - val_acc: 0.8907\n",
      "Epoch 5/30\n",
      "12/12 [==============================] - 7s 620ms/step - loss: 0.2774 - acc: 0.8889 - val_loss: 0.2527 - val_acc: 0.8907\n",
      "Epoch 6/30\n",
      "12/12 [==============================] - 8s 662ms/step - loss: 0.2488 - acc: 0.8902 - val_loss: 0.2311 - val_acc: 0.9193\n",
      "Epoch 7/30\n",
      "12/12 [==============================] - 7s 614ms/step - loss: 0.2264 - acc: 0.9118 - val_loss: 0.2206 - val_acc: 0.9243\n",
      "Epoch 8/30\n",
      "12/12 [==============================] - 8s 667ms/step - loss: 0.2084 - acc: 0.9347 - val_loss: 0.2185 - val_acc: 0.9158\n",
      "Epoch 9/30\n",
      "12/12 [==============================] - 8s 640ms/step - loss: 0.2012 - acc: 0.9380 - val_loss: 0.2146 - val_acc: 0.9188\n",
      "Epoch 10/30\n",
      "12/12 [==============================] - 7s 621ms/step - loss: 0.2021 - acc: 0.9301 - val_loss: 0.2158 - val_acc: 0.9240\n",
      "Epoch 11/30\n",
      "12/12 [==============================] - 8s 675ms/step - loss: 0.1944 - acc: 0.9367 - val_loss: 0.2254 - val_acc: 0.9101\n",
      "Epoch 12/30\n",
      "12/12 [==============================] - 8s 673ms/step - loss: 0.1904 - acc: 0.9341 - val_loss: 0.2141 - val_acc: 0.9207\n",
      "Epoch 13/30\n",
      "12/12 [==============================] - 7s 622ms/step - loss: 0.1785 - acc: 0.9406 - val_loss: 0.2164 - val_acc: 0.9195\n",
      "Epoch 14/30\n",
      "12/12 [==============================] - 7s 621ms/step - loss: 0.1702 - acc: 0.9432 - val_loss: 0.2188 - val_acc: 0.9200\n",
      "Epoch 15/30\n",
      "12/12 [==============================] - 7s 582ms/step - loss: 0.1738 - acc: 0.9399 - val_loss: 0.2218 - val_acc: 0.9177\n",
      "Epoch 16/30\n",
      "12/12 [==============================] - 7s 580ms/step - loss: 0.1644 - acc: 0.9406 - val_loss: 0.2223 - val_acc: 0.9160\n",
      "Epoch 17/30\n",
      "12/12 [==============================] - 8s 652ms/step - loss: 0.1713 - acc: 0.9341 - val_loss: 0.2230 - val_acc: 0.9167\n",
      "Epoch 18/30\n",
      "12/12 [==============================] - 7s 584ms/step - loss: 0.1687 - acc: 0.9393 - val_loss: 0.2262 - val_acc: 0.9160\n",
      "Epoch 19/30\n",
      "12/12 [==============================] - 7s 620ms/step - loss: 0.1685 - acc: 0.9426 - val_loss: 0.2224 - val_acc: 0.9179\n",
      "Epoch 20/30\n",
      "12/12 [==============================] - 8s 645ms/step - loss: 0.1555 - acc: 0.9439 - val_loss: 0.2420 - val_acc: 0.9042\n",
      "Epoch 21/30\n",
      "12/12 [==============================] - 8s 649ms/step - loss: 0.1730 - acc: 0.9406 - val_loss: 0.2194 - val_acc: 0.9179\n",
      "Epoch 22/30\n",
      "12/12 [==============================] - 7s 572ms/step - loss: 0.1606 - acc: 0.9425 - val_loss: 0.2232 - val_acc: 0.9203\n",
      "Epoch 23/30\n",
      "12/12 [==============================] - 7s 614ms/step - loss: 0.1529 - acc: 0.9504 - val_loss: 0.2301 - val_acc: 0.9096\n",
      "Epoch 24/30\n",
      "12/12 [==============================] - 7s 575ms/step - loss: 0.1606 - acc: 0.9399 - val_loss: 0.2243 - val_acc: 0.9210\n",
      "Epoch 25/30\n",
      "12/12 [==============================] - 7s 588ms/step - loss: 0.1453 - acc: 0.9536 - val_loss: 0.2297 - val_acc: 0.9122\n",
      "Epoch 26/30\n",
      "12/12 [==============================] - 8s 701ms/step - loss: 0.1427 - acc: 0.9517 - val_loss: 0.2262 - val_acc: 0.9165\n",
      "Epoch 27/30\n",
      "12/12 [==============================] - 7s 568ms/step - loss: 0.1424 - acc: 0.9451 - val_loss: 0.2420 - val_acc: 0.9113\n",
      "Epoch 28/30\n",
      "12/12 [==============================] - 8s 632ms/step - loss: 0.1402 - acc: 0.9471 - val_loss: 0.2382 - val_acc: 0.9167\n",
      "Epoch 29/30\n",
      "12/12 [==============================] - 8s 688ms/step - loss: 0.1393 - acc: 0.9530 - val_loss: 0.2369 - val_acc: 0.9103\n",
      "Epoch 30/30\n",
      "12/12 [==============================] - 7s 590ms/step - loss: 0.1352 - acc: 0.9517 - val_loss: 0.2465 - val_acc: 0.9080\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x7fbc080336d8>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learner.autofit(0.005, 30, checkpoint_folder='/tmp/saved_weights')"
   ]
  },
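  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The triangular learning rate policy used by `autofit` moves the learning rate linearly up from a lower bound to the maximum (0.005 above) and back down again. The cell below is a minimal sketch of a generic triangular cyclical schedule, not ktrain's own implementation; the function name `triangular_lr`, the lower bound `base_lr`, and the half-cycle length `step_size` are illustrative assumptions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "def triangular_lr(iteration, step_size, base_lr, max_lr):\n",
    "    # Hypothetical helper: generic triangular cyclical learning rate (not ktrain's code).\n",
    "    # step_size is the number of iterations in half a cycle (lr rising or falling).\n",
    "    cycle = math.floor(1 + iteration / (2 * step_size))\n",
    "    # x is 1 at the cycle boundaries and 0 at the peak of the triangle\n",
    "    x = abs(iteration / step_size - 2 * cycle + 1)\n",
    "    return base_lr + (max_lr - base_lr) * max(0.0, 1 - x)\n",
    "\n",
    "# Assumed settings: max_lr matches autofit above; base_lr and step_size are illustrative\n",
    "schedule = [triangular_lr(i, step_size=6, base_lr=0.0005, max_lr=0.005) for i in range(13)]\n",
    "print([round(lr, 5) for lr in schedule])"
   ]
  },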
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluate"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's use the weights from Epoch 12, which achieved the lowest validation loss (0.2141)."
   ]
  },
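  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Rather than reading the best epoch off the log by eye, it can be selected programmatically. As a minimal sketch, the per-epoch validation losses below are copied from the `autofit` log above; in practice they could instead be read from the Keras `History` object that `autofit` returns (its `history['val_loss']` list). The variable names are illustrative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Per-epoch validation losses copied from the autofit log above (epochs 1-30)\n",
    "val_losses = [0.4119, 0.3422, 0.3063, 0.2740, 0.2527, 0.2311, 0.2206, 0.2185, 0.2146, 0.2158,\n",
    "              0.2254, 0.2141, 0.2164, 0.2188, 0.2218, 0.2223, 0.2230, 0.2262, 0.2224, 0.2420,\n",
    "              0.2194, 0.2232, 0.2301, 0.2243, 0.2297, 0.2262, 0.2420, 0.2382, 0.2369, 0.2465]\n",
    "\n",
    "# Epochs are 1-indexed in the log and in the checkpoint file names (e.g., weights-12.hdf5)\n",
    "best_epoch = int(np.argmin(val_losses)) + 1\n",
    "print(best_epoch, val_losses[best_epoch - 1])"
   ]
  },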
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "learner.model.load_weights('/tmp/saved_weights/weights-12.hdf5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "     hateful       0.68      0.57      0.62       462\n",
      "      normal       0.95      0.97      0.96      3764\n",
      "\n",
      "    accuracy                           0.92      4226\n",
      "   macro avg       0.82      0.77      0.79      4226\n",
      "weighted avg       0.92      0.92      0.92      4226\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([[ 262,  200],\n",
       "       [ 122, 3642]])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learner.validate(class_names=preproc.get_classes())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Predict\n",
    "\n",
    "Let's make predictions for all unlabeled Twitter users (i.e., users for whom we do not know whether or not they are hateful)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "p = ktrain.get_predictor(learner.model, preproc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_unlabeled = preproc.df[preproc.df.target=='unknown']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds = p.predict_transductive(df_unlabeled.index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "preds = np.array(preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "df_preds = pd.DataFrame(list(zip(df_unlabeled.index, preds)), columns=['UserID', 'Predicted'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>UserID</th>\n",
       "      <th>Predicted</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>52</th>\n",
       "      <td>56</td>\n",
       "      <td>hateful</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>129</th>\n",
       "      <td>140</td>\n",
       "      <td>hateful</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>243</th>\n",
       "      <td>259</td>\n",
       "      <td>hateful</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>475</th>\n",
       "      <td>499</td>\n",
       "      <td>hateful</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>501</th>\n",
       "      <td>526</td>\n",
       "      <td>hateful</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    UserID Predicted\n",
       "52      56   hateful\n",
       "129    140   hateful\n",
       "243    259   hateful\n",
       "475    499   hateful\n",
       "501    526   hateful"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_preds[df_preds.Predicted=='hateful'].head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "578"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_preds[df_preds.Predicted=='hateful'].shape[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Out of over 95,000 unlabeled nodes in the Twitter graph, our model flagged **578** as potentially hateful users who may warrant review."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
